{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers installation\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image captioning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Image captioning is the task of predicting a caption for a given image. Common real world applications of it include\n",
    "aiding visually impaired people that can help them navigate through different situations. Therefore, image captioning\n",
    "helps to improve content accessibility for people by describing images to them.\n",
    "\n",
    "This guide will show you how to:\n",
    "\n",
    "* Fine-tune an image captioning model.\n",
    "* Use the fine-tuned model for inference.\n",
    "\n",
    "Before you begin, make sure you have all the necessary libraries installed:\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate -q\n",
    "pip install jiwer -q\n",
    "```\n",
    "\n",
    "We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pokémon BLIP captions dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the 🤗 Dataset library to load a dataset that consists of {image-caption} pairs. To create your own image captioning dataset\n",
    "in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"lambdalabs/pokemon-blip-captions\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['image', 'text'],\n",
    "        num_rows: 833\n",
    "    })\n",
    "})\n",
    "```\n",
    "\n",
    "The dataset has two features, `image` and `text`.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample a caption amongst the available ones during training.\n",
    "\n",
    "</Tip>\n",
    "\n",
    "Split the dataset's train split into a train and test set with the [train_test_split](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.train_test_split) method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds[\"train\"].train_test_split(test_size=0.1)\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's visualize a couple of samples from the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textwrap import wrap\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_images(images, captions):\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    for i in range(len(images)):\n",
    "        ax = plt.subplot(1, len(images), i + 1)\n",
    "        caption = captions[i]\n",
    "        caption = \"\\n\".join(wrap(caption, 12))\n",
    "        plt.title(caption)\n",
    "        plt.imshow(images[i])\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "\n",
    "sample_images_to_visualize = [np.array(train_ds[i][\"image\"]) for i in range(5)]\n",
    "sample_captions = [train_ds[i][\"text\"] for i in range(5)]\n",
    "plot_images(sample_images_to_visualize, sample_captions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png\" alt=\"Sample training images\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the dataset has two modalities (image and text), the pre-processing pipeline will preprocess images and the captions.\n",
    "\n",
    "To do so, load the processor class associated with the model you are about to fine-tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "checkpoint = \"microsoft/git-base\"\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The processor will internally pre-process the image (which includes resizing, and pixel scaling) and tokenize the caption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(example_batch):\n",
    "    images = [x for x in example_batch[\"image\"]]\n",
    "    captions = [x for x in example_batch[\"text\"]]\n",
    "    inputs = processor(images=images, text=captions, padding=\"max_length\")\n",
    "    inputs.update({\"labels\": inputs[\"input_ids\"]})\n",
    "    return inputs\n",
    "\n",
    "\n",
    "train_ds.set_transform(transforms)\n",
    "test_ds.set_transform(transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the dataset ready, you can now set up the model for fine-tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a base model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the [\"microsoft/git-base\"](https://huggingface.co/microsoft/git-base) into a [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Image captioning models are typically evaluated with the [Rouge Score](https://huggingface.co/spaces/evaluate-metric/rouge) or [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer). For this guide, you will use the Word Error Rate (WER).\n",
    "\n",
    "We use the 🤗 Evaluate library to do so. For potential limitations and other gotchas of the WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "import torch\n",
    "\n",
    "wer = load(\"wer\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predicted = logits.argmax(-1)\n",
    "    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)\n",
    "    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)\n",
    "    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)\n",
    "    return {\"wer_score\": wer_score}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, you are ready to start fine-tuning the model. You will use the 🤗 [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) for this.\n",
    "\n",
    "First, define the training arguments using [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "model_name = checkpoint.split(\"/\")[1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-pokemon\",\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=50,\n",
    "    fp16=True,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=2,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=50,\n",
    "    logging_steps=50,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    "    label_names=[\"labels\"],\n",
    "    load_best_model_at_end=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then pass them along with the datasets and the model to 🤗 Trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To start training, simply call [train()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train) on the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should see the training loss drop smoothly as training progresses.\n",
    "\n",
    "Once training is completed, share your model to the Hub with the [push_to_hub()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub) method so everyone can use your model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take a sample image from `test_ds` to test the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "url = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png\" alt=\"Test image\"/>\n",
    "</div>\n",
    "\n",
    "Prepare image for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "device = Accelerator().device\n",
    "inputs = processor(images=image, return_tensors=\"pt\").to(device)\n",
    "pixel_values = inputs.pixel_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call `generate` and decode the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_ids = model.generate(pixel_values=pixel_values, max_length=50)\n",
    "generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "print(generated_caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "a drawing of a pink and blue pokemon\n",
    "```\n",
    "\n",
    "Looks like the fine-tuned model generated a pretty good caption!"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
